{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torchvision.models as models\n",
    "from utils import Dataset\n",
    "import torch.nn.functional as F"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 105,
   "metadata": {},
   "outputs": [],
   "source": [
    "# how to test if this kwinners implementation isd oing the right thing?\n",
    "# I can test it directly in a model\n",
    "# or try to implement the same class in a more simple setting\n",
    "# let's do the simple setting"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "from sklearn import datasets\n",
    "iris = datasets.load_iris()\n",
    "x = torch.tensor(iris.data, dtype=torch.float)\n",
    "y = torch.tensor(iris.target, dtype=torch.long)\n",
    "x.shape, y.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = Dataset(config=dict(dataset_name='MNIST', data_dir='~/nta/results'))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 107,
   "metadata": {},
   "outputs": [],
   "source": [
    "# build up a small neural network\n",
    "inputs = []\n",
    "\n",
    "def init_weights():\n",
    "    W1 = torch.randn((4,10), requires_grad=True)\n",
    "    b1 = torch.zeros(10, requires_grad=True)\n",
    "    W2 = torch.randn((10,3), requires_grad=True)\n",
    "    b2 = torch.zeros(3, requires_grad=True)\n",
    "    return [W1, b1, W2, b2]\n",
    "\n",
    "# torch cross_entropy is log softmax activation + negative log likelihood\n",
    "loss_func = F.cross_entropy\n",
    "\n",
    "# simple feedforward model\n",
    "def model(input):\n",
    "    W1, b1, W2, b2 = parameters\n",
    "    x = input @ W1 + b1\n",
    "    x = F.relu(x)\n",
    "    x = x @ W2 + b2\n",
    "    return x\n",
    "  \n",
    "# calculate accuracy\n",
    "def accuracy(out, y):\n",
    "    preds = torch.argmax(out, dim=1)\n",
    "    return (preds == y).float().mean().item()\n",
    "  \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 108,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.model_selection import StratifiedKFold\n",
    "cv = StratifiedKFold(n_splits=3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 111,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy before training: 0.2733\n",
      "Loss: 11.21857452\n",
      "Loss: 0.41098920\n",
      "Loss: 0.25873697\n",
      "Loss: 0.20796219\n",
      "Loss: 0.18057358\n",
      "Training Accuracy after training: 0.9192\n",
      "Test Accuracy after training: 1.0000\n",
      "---------------------------\n",
      "Accuracy before training: 0.3333\n",
      "Loss: 13.61304092\n",
      "Loss: 0.34098408\n",
      "Loss: 0.24114083\n",
      "Loss: 0.18845302\n",
      "Loss: 0.15532117\n",
      "Training Accuracy after training: 0.9596\n",
      "Test Accuracy after training: 0.9216\n",
      "---------------------------\n",
      "Accuracy before training: 0.3333\n",
      "Loss: 29.98571205\n",
      "Loss: 0.25740978\n",
      "Loss: 0.18164127\n",
      "Loss: 0.15090778\n",
      "Loss: 0.13390265\n",
      "Training Accuracy after training: 0.9608\n",
      "Test Accuracy after training: 0.9792\n",
      "---------------------------\n"
     ]
    }
   ],
   "source": [
    "# train\n",
    "lr = 0.01\n",
    "epochs = 1000\n",
    "for train, test in cv.split(x, y):\n",
    "    x_train, y_train = x[train], y[train] \n",
    "    x_test, y_test = x[test], y[test] \n",
    "    parameters = init_weights()\n",
    "    print(\"Accuracy before training: {:.4f}\".format(accuracy(model(x), y)))\n",
    "    for epoch in range(epochs):\n",
    "        loss = loss_func(model(x_train), y_train)\n",
    "        if epoch % (epochs/5) == 0:\n",
    "          print(\"Loss: {:.8f}\".format(loss.item()))\n",
    "        # backpropagate\n",
    "        loss.backward()\n",
    "        with torch.no_grad():\n",
    "          for param in parameters:\n",
    "              # update weights\n",
    "              param -= lr * param.grad\n",
    "              # zero gradients\n",
    "              param.grad.zero_()\n",
    "\n",
    "    print(\"Training Accuracy after training: {:.4f}\".format(accuracy(model(x_train), y_train)))\n",
    "    print(\"Test Accuracy after training: {:.4f}\".format(accuracy(model(x_test), y_test)))\n",
    "    print(\"---------------------------\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Seems to be overfitting the model nicely. Actions:\n",
    "- Test accuracy - DONE\n",
    "- Repeat the experiment with a held out test set, still holds?  - DONE\n",
    "- Replace RELU with k-Winners - is k-Winners working? - TODO\n",
    "- Extend to larger dataset, MNIST\n",
    "- Replace RELU with a class\n",
    "- Extend to larger model, CNNs\n",
    "- Run similar tests for both RELU and k-Winners - results hold?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "import torch\n",
    "from torch import nn\n",
    "from torchvision import models\n",
    "\n",
    "class KWinners(nn.Module):\n",
    "\n",
    "    def __init__(self, k=10):\n",
    "        super(KWinners, self).__init__()\n",
    "\n",
    "        self.duty_cycle = None\n",
    "        self.k = 10\n",
    "        self.beta = 100\n",
    "        self.T = 1000\n",
    "        self.current_time = 0\n",
    "\n",
    "    def forward(self, x):\n",
    "\n",
    "        # initialize duty cycle\n",
    "        if self.duty_cycle is None:\n",
    "            self.duty_cycle = torch.zeros_like(k)\n",
    "\n",
    "        # keep track of number of past iteratctions\n",
    "        if self.current_time < self.T:\n",
    "            self.current_time += 1\n",
    "\n",
    "        # calculating threshold and updating duty cycle \n",
    "        # should not be in the graph\n",
    "        tx = x.clone().detach()\n",
    "        # no need to calculate gradients\n",
    "        with torch.set_grad_enabled(False):\n",
    "            # get threshold\n",
    "            # nonzero_mask = torch.nonzero(tx) # will need for sparse weights\n",
    "            threshold = self._get_threshold(tx)\n",
    "            # calculate boosting\n",
    "            self._update_duty_cycle(mask)\n",
    "            boosting = self._calculate_boosting()\n",
    "            # get mask\n",
    "            tx *= boosting\n",
    "            mask = tx > threshold\n",
    "            \n",
    "        return x * mask\n",
    "\n",
    "    def _get_threshold(self, x):\n",
    "        \"\"\"Calculate dynamic theshold\"\"\" \n",
    "        abs_x = torch.abs(x).view(-1)\n",
    "        pos = abs_x.size()[0] - self.k\n",
    "        threshold, _ = torch.kthvalue(abs_x, pos)\n",
    "\n",
    "        return threshold\n",
    "\n",
    "    def _update_duty_cycle(self, mask):\n",
    "        \"\"\"Update duty cycle\"\"\" \n",
    "        time = min(self.T, self.current_time)\n",
    "        self.duty_cycle *= (time-1)/time\n",
    "        self.duty_cycle += mask.float() / time\n",
    "\n",
    "    def _calculate_boosting(self):\n",
    "        \"\"\"Calculate boosting according to formula on spatial pooling paper\"\"\"\n",
    "        mean_duty_cycle = torch.mean(self.duty_cycle)\n",
    "        diff_duty_cycle = self.duty_cycle - mean_duty_cycle\n",
    "        boosting = (self.beta * diff_duty_cycle).exp()\n",
    "\n",
    "        return boosting\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
